from turtle import color
import numpy as np
from collections import defaultdict
from matplotlib import pyplot as plt
from matplotlib import font_manager as fm, rcParams
from matplotlib import rc
import matplotlib.patheffects as pe
import matplotlib.colors as mcolors
import os
import pandas as pd
import seaborn as sns
import argparse


# Global seaborn/matplotlib styling applied to every figure created below.
s = 20  # font size used for the x/y tick labels
rc_ = {
    'figure.figsize': (8, 8),
    'axes.labelsize': 30,
    'xtick.labelsize': s,
    'ytick.labelsize': s,
    'legend.fontsize': 20,
}
sns.set(rc=rc_, style="darkgrid")

# seaborn's "colorblind" palette is ordered blue, orange, green, red, ...
# so green and red live at indices 2 and 3 (index 1 is orange).
_colorblind = sns.color_palette("colorblind")
cblue = _colorblind[0]
cgreen = _colorblind[2]
cred = _colorblind[3]
# rc('text', usetex=True)

# Command-line interface: --path is the directory the PDF figures are saved to.
parser = argparse.ArgumentParser()
parser.add_argument('--path', default='./images', help="path")
args = parser.parse_args()

# Create the output directory before the (long) sweep starts, so a missing
# directory does not surface only when the first figure is saved.
if args.path:
    os.makedirs(args.path, exist_ok=True)

# Vary p [0 1], Vary n [0 2] in p^n, Vary rmin [-5 0) 

# #####################################################################################
# Range of the per-step reward and resolution of the (reward, penalty) grid.
rmin, rmax = -1, 0
n = 1000

# Result grids, indexed [reward index, penalty index].
convergences = np.zeros((n, n))
successes = np.zeros_like(convergences)

# Each "line" is a pair of parallel lists: [reward indices, penalty indices].
minmax_line, dc_minmax_line, d_minmax_line, c_minmax_line = (
    [[], []] for _ in range(4)
)

# Sample axes: penalty runs from -10 up to 0, reward from rmax down to rmin.
penalty = np.linspace(-10, 0, num=n)
reward = np.linspace(rmax, rmin, num=n)
# The final sample of each axis is overwritten with its predecessor, so the
# end points penalty == 0 and reward == rmin are never evaluated.
penalty[-1], reward[-1] = penalty[-2], reward[-2]

# Transition probability used by the MDP built in the sweep below.
p = 0.4

# Penalty levels drawn as horizontal reference lines in the figures.
r_minmax = -2.63 # p1=p2=0.4
r_DC = -2 / (1 - p)
r_D = -2
# r_minmax = -2.63 # p1=0
# r_DC = r_minmax
# r_D = r_minmax
# r_minmax = -2 # p2=0
# r_DC = -2.73
# r_D = -1.62

# NOTE: the triple-quoted block below is a bare string literal, not a comment
# or docstring: it is evaluated and discarded, so none of this code runs.
# It holds a disabled computation of penaltyR / penaltyD / penaltyC, which are
# only referenced from the second disabled string inside the sweep loop.
# It indexes `p` as an array (`p[1:]`, `p.shape[0]`), whereas the live `p`
# above is a scalar, so it cannot be re-enabled without changes.
"""
delta_p_s0 = (1-p) #- p
delta_p_s0[0] = 0
delta_p_s0_ = p #- (1-p)
delta_p_s0_[0] = 0
delta_p_s0_a = np.max([delta_p_s0,delta_p_s0_], axis=0)
delta_p_s0_b = np.min([delta_p_s0,delta_p_s0_], axis=0)
delta_p_s0_c = delta_p_s0_a + delta_p_s0_b
delta_p_s2 = (1-p) - (1-p)
C = delta_p_s0_a

p = p[1:]
C = C[1:]

P = np.zeros((p.shape[0], 4, 4, 2)) # p, S, S, A
P[:,3,3,:] = 1.0
P[:,1,1,:] = 1.0
P[:,2,2,0] = p
P[:,2,2,1] = p
P[:,2,3,0] = 1-p
P[:,2,3,1] = 1-p
P[:,0,2,0] = p
P[:,0,2,1] = 1-p
P[:,0,1,0] = 1-p
P[:,0,1,1] = p
R = np.ones((p.shape[0], 4, 4, 2)) # p, S, A
R[:,[1,3],:,:] = 0.0
V = np.zeros((p.shape[0], 4)) # p, S
while True:
    V_pre = V.copy()
    for s in range(4):
        V[:,s] = np.array([(np.array([P[:,s,s_,a]*(R[:,s,s_,a] + V[:,s_]) for s_ in range(4)]).sum(axis=0)) for a in range(2)]).max(axis=0)
    if np.abs(V_pre-V).max() <= 0:
        break
D = V.max(axis=1)

a = (rmin) + np.zeros(p.shape[0])
b = (rmin-rmax)*(D/C)
penaltyR = np.min([a, b], axis=0)
b = (rmin-rmax)*D
penaltyD = np.min([a, b], axis=0)
b = (rmin-rmax)*(1/C)
penaltyC = np.min([a, b], axis=0)
"""
# probs = np.linspace(1,0,n)
# Sweep over the step reward. For every reward value, run undiscounted
# Q-value iteration on a 4-state, 2-action MDP for all penalty samples at once
# (the leading array axis indexes the penalty sample), then record which
# action is greedy in state 0.
states = 4
maxstep = 10000  # hard cap on value-iteration sweeps per reward value
n_penalties = penalty.shape[0]

# Transition tensor indexed [penalty sample, state, next state, action].
# It depends on neither the reward nor the penalty, so it is built once.
#   states 1 and 3: absorbing (self-loop with probability 1)
#   state 2: stays with probability p, moves to 3 with 1-p (both actions)
#   state 0: action 0 -> state 1 w.p. 1-p, state 2 w.p. p
#            action 1 -> state 1 w.p. p,   state 2 w.p. 1-p
P = np.zeros((n_penalties, states, states, 2)) # p, S, S, A
P[:,3,3,:] = 1.0
P[:,1,1,:] = 1.0
P[:,2,2,:] = np.array([p,p])
P[:,2,3,:] = np.array([1-p,1-p])
P[:,0,1,:] = np.array([1-p,p])
P[:,0,2,:] = np.array([p,1-p])

for ip, r in enumerate(reward): #[0, 0.1, 0.25]:
    # Progress output: print the current reward every 100 iterations.
    if ip%100==0:
        print(r)
    # The string below is disabled code (a bare string literal, never run);
    # it depends on penaltyR / penaltyD / penaltyC, which only exist inside
    # the disabled string earlier in the file.
    """
    # D, C = 2, max([(1-p)-p,p-(1-p)]) # 0.72
    if p!=1:
        mm = penaltyR[ip-1]
        idx = np.abs(penalty - mm).argmin()
        if mm>=penalty[0] and mm<=penalty[-1]:
            dc_minmax_line[0].append(ip)
            dc_minmax_line[1].append(idx)

        mm = penaltyD[ip-1]
        idx = np.abs(penalty - mm).argmin()
        if mm>=penalty[0] and mm<=penalty[-1]:
            d_minmax_line[0].append(ip)
            d_minmax_line[1].append(idx)

        mm = penaltyC[ip-1]
        idx = np.abs(penalty - mm).argmin()
        if mm>=penalty[0] and mm<=penalty[-1]:
            c_minmax_line[0].append(ip)
            c_minmax_line[1].append(idx)
    """
    # Reward tensor, same layout as P: `r` on every transition, except that
    # transitions out of states 1 and 3 pay 0 and the 0 -> 1 transition pays
    # the per-sample penalty under either action.
    R = np.full((n_penalties, states, states, 2), r, dtype=float) # p, S, S, A
    R[:,[1,3],:,:] = 0.0
    R[:,0,1,:] = penalty[:, np.newaxis]

    # Q-values indexed [penalty sample, state, action], initialised to rmin.
    Q = np.full((n_penalties, states, 2), rmin, dtype=float) # p, S, A

    step = 0
    # NOTE: per-sample convergence tracking is disabled (see the commented
    # code below), so `convergence` stays all zeros.
    convergence = np.zeros(n_penalties)
    while True:
        step += 1
        Q_pre = Q.copy()
        # In-place sweep: each state's update already sees the values written
        # for the previous states in this same sweep.
        for state in range(states):
            next_value = Q.max(axis=2)  # greedy value of every next state
            expected = (
                P[:, state] * (R[:, state] + next_value[:, :, np.newaxis])
            ).sum(axis=1)  # sum over next states -> [penalty sample, action]
            # Written as "old + (target - old)" to keep the exact arithmetic
            # of the original update.
            Q[:, state, :] = Q[:, state, :] + (expected - Q[:, state, :])
        # for i in range(penalty.shape[0]): 
        #     if np.abs(Q_pre[i]-Q[i]).max() <= 1e-5 and convergence[i] == 0:
        #         convergence[i] = step
        #     if step>maxstep and convergence[i] == 0:
        #         convergence[i] = maxstep
        if np.abs(Q_pre - Q).max() <= 1e-10 or step > maxstep:
            break
    #convergence = (convergence - convergence.min())/(convergence.max() - convergence.min())
    #convergences.append(convergence)
    convergences[ip,:] = convergence
    # print("step",step)
    # print("convergence",convergence)

    # Probability of the 0 -> 2 transition under the greedy action in state 0:
    # 1-p when action 1 is greedy, p otherwise (ties resolve to action 0).
    greedy_is_action_1 = Q[:, 0].argmax(axis=1) == 1
    success = np.where(greedy_is_action_1, 1 - p, p)

    # Record every penalty index at which the outcome differs from its
    # predecessor, paired with the current reward index.
    flips = np.flatnonzero(success[1:] != success[:-1]) + 1
    minmax_line[0].extend([ip] * flips.size)
    minmax_line[1].extend(flips.tolist())
    successes[ip,:] = success
# convergences = np.array(convergences)
# Turn the [reward indices, penalty indices] list pairs into 2-row arrays.
minmax_line, dc_minmax_line, d_minmax_line, c_minmax_line = (
    np.array(line)
    for line in (minmax_line, dc_minmax_line, d_minmax_line, c_minmax_line)
)
# Re-orient the grid for display: reverse both axes, then transpose, so that
# entry [reward index a, penalty index b] ends up at [n-1-b, n-1-a]
# (penalty on the rows, reward on the columns).
successes = successes[::-1, ::-1].T
# #####################################################################################

# Build the heat-map colormap from the colours found along the middle pixel
# row of a reference image.
cim = plt.imread("./images/cmap_1.png")
middle_row = cim.shape[0] // 2
cim = cim[middle_row, :, :]
cmap = mcolors.ListedColormap(cim)
# cmap="RdYlBu_r"


# print(list(zip(probs[minmax_line[0]],penalty[minmax_line[1]])))
# Figure 1: heat map of 1 - successes over the (R_step, penalty) grid, with a
# dashed reference line at the penalty sample closest to r_minmax.
print("Plotting")
fig = plt.figure(dpi=60)

#plt.plot(np.arange(0,len(minmax_line[0])),n-minmax_line[1][::-1], color="black", linestyle="--", linewidth=2, path_effects=[pe.Stroke(linewidth=4, foreground='w'), pe.Normal()])
# After the re-orientation of `successes`, penalty index i is drawn on image
# row n-1-i, and imshow centres row k at y == k, so the line for the nearest
# penalty sample belongs at (n - 1) - index.
minmax_row = (n - 1) - np.abs(penalty - r_minmax).argmin()
plt.axhline(
    y=minmax_row,
    color="black",
    linestyle="--",
    linewidth=2,
    # white outline keeps the line readable on top of the heat map
    path_effects=[pe.Stroke(linewidth=4, foreground='w'), pe.Normal()],
)

c = plt.imshow(1-successes, cmap=cmap, vmin=0, vmax=1) # cmap="RdYlBu_r"
fig.colorbar(c, fraction=0.045)
plt.ylabel(r"Penalty $\in [-10 ~ 0]$")
plt.xlabel(r"$R_{step} \in [-1 ~ 0]$")
plt.grid(False)
plt.xticks([])
plt.yticks([])
# plt.xticks(range(len(penalty)),penalty)
# plt.yticks(range(len(probs)),probs)
fig.tight_layout()
plt.savefig("{}/{}.pdf".format(args.path, "reward_vs_penalty_p"), bbox_inches='tight')
# plt.show()

# Figure 2: the same heat map as figure 1, plus labelled reference lines at
# the penalty samples closest to r_minmax, r_DC and r_D, and a legend.
print("Plotting")
fig = plt.figure(dpi=60)

#plt.plot(np.arange(0,len(minmax_line[0])),n-minmax_line[1][::-1], color="black", linestyle="--", linewidth=2, path_effects=[pe.Stroke(linewidth=4, foreground='w'), pe.Normal()])
#plt.plot(np.arange(0,len(dc_minmax_line[0])),n-dc_minmax_line[1][::-1], color="blue", linewidth=2, path_effects=[pe.Stroke(linewidth=4, foreground='w'), pe.Normal()])
#plt.plot(np.arange(0,len(d_minmax_line[0])),n-d_minmax_line[1][::-1], color="red", linewidth=2, path_effects=[pe.Stroke(linewidth=4, foreground='w'), pe.Normal()])

# (penalty level, line colour, line style, legend label), in legend order.
reference_lines = (
    (r_minmax, "black", "--", r'$R_{Minmax}$'),
    (r_DC, "blue", "-", r'$\bar R_{MIN}$'),
    (r_D, "red", "-", r'$\bar R_{MAX}$'),
)
for level, line_color, line_style, line_label in reference_lines:
    # After the re-orientation of `successes`, penalty index i is drawn on
    # image row n-1-i, and imshow centres row k at y == k.
    level_row = (n - 1) - np.abs(penalty - level).argmin()
    plt.axhline(
        y=level_row,
        color=line_color,
        linestyle=line_style,
        label=line_label,
        linewidth=2,
        # white outline keeps the line readable on top of the heat map
        path_effects=[pe.Stroke(linewidth=4, foreground='w'), pe.Normal()],
    )

c = plt.imshow(1-successes, cmap=cmap, vmin=0, vmax=1) # cmap="RdYlBu_r"
fig.colorbar(c, fraction=0.045)
plt.ylabel(r"Penalty $\in [-10 ~ 0]$")
plt.xlabel(r"$R_{step} \in [-1 ~ 0]$")
plt.grid(False)
plt.xticks([])
plt.yticks([])
# Semi-transparent black legend box with white labels.
legend = plt.legend(loc='lower left', labelcolor='white', fancybox=True, framealpha=0.35, frameon=True)
legend.get_frame().set_facecolor((0, 0, 0, 1))
# plt.xticks(range(len(penalty)),penalty)
# plt.yticks(range(len(probs)),probs)
fig.tight_layout()
plt.savefig("{}/{}.pdf".format(args.path, "reward_vs_penalty_bounds_p"), bbox_inches='tight')
# plt.show()

# Figure 3 (convergence-step heat map): only the two statements below are
# live. Everything that would draw and save the figure is commented out, so
# an empty figure is created and never written to disk.
# NOTE(review): `convergences` is all zeros while the per-sample convergence
# tracking in the sweep loop stays commented out — re-enable both together.
print("Plotting")
fig = plt.figure(dpi=60)

# #plt.plot(np.arange(0,len(minmax_line[0])),n-minmax_line[1][::-1], color="black", linestyle="--", linewidth=2, path_effects=[pe.Stroke(linewidth=4, foreground='w'), pe.Normal()])
# #plt.axhline(y=n-np.abs(penalty - r_minmax).argmin(), color="black", linestyle="--", linewidth=2, path_effects=[pe.Stroke(linewidth=4, foreground='w'), pe.Normal()])
# print(convergences)
# c = plt.imshow(convergences, cmap=cmap) # cmap="RdYlBu_r"
# fig.colorbar(c,fraction=0.045)
# plt.ylabel(r"Penalty $\in [-10 ~ 0]$")
# plt.xlabel(r"$R_{step} \in [-1 ~ 0]$")
# plt.grid(False)
# plt.xticks([])
# plt.yticks([])
# # plt.xticks(range(len(penalty)),penalty)
# # plt.yticks(range(len(probs)),probs)
# fig.tight_layout()
# plt.savefig("{}/{}.pdf".format(args.path,f"reward_vs_penalty_steps"), bbox_inches='tight')
# # plt.show()
